Conversation
|
This pull request has merge conflicts that must be resolved before it can be |
There was a problem hiding this comment.
Code Review
This pull request introduces support for DFlash speculative decoding, a new method that leverages bidirectional attention. The implementation is comprehensive, touching configuration, model definition, the speculative proposer, and the core model runner. Key changes include:
- A new
DFlashProposerandqwen3_dflashmodel to implement the DFlash architecture. - Refactoring of
eagle.pyto better support different speculative decoding methods. - Configuration updates for auto-detection and setup of DFlash.
- Extensive end-to-end tests for both correctness and acceptance rate, which is great to see.
The code is well-structured, and the refactoring improves extensibility. My main concern, which you've also highlighted in the PR description and code comments, is the compatibility with CUDA graphs due to the torch.cat operation on potentially differently padded tensors. This is a critical point for performance and stability, and I've left a comment with a suggestion on how to address it.
|
Update on the graphs, I have a local workaround that I'm working on cleaning up. The solution is to put the context states in the forward_context and access them via CustomOp similar to unified_attention_with_output. Then all the ordinary logic can go in the main graph. |
|
This pull request has merge conflicts that must be resolved before it can be |
|
Thanks @benchislett for implementing DFlash in vLLM. We’ve just released DFlash checkpoints for Qwen3.5-4B, 9B, 27B, and 35B-A3B. They perform very well and are faster than MTP. Feel free to try them out. We’ll continue adding support for more models, and I’m really looking forward to seeing DFlash run in vLLM! |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
|
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
27773c5 to
d9a63c2
Compare
|
(force-pushed to fix DCO. manually checked all test cases still pass) |
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
|
Regarding this issue: It will now raise a clear error on startup as of e99905a. |
| gpu_memory_utilization=0.85, | ||
| enforce_eager=False, | ||
| disable_log_stats=False, | ||
| attention_config={"backend": "FLASH_ATTN"}, # Required for non-causal attention |
There was a problem hiding this comment.
Can we get this working across the board without needing to specify this arg? We should be able to resolve this internally by querying the attention backend's supports_attn_type during selection
| def use_eagle(self) -> bool: | ||
| return self.method in ("eagle", "eagle3", "mtp") | ||
| return self.method in ("eagle", "eagle3", "mtp", "dflash") | ||
|
|
||
| def use_dflash(self) -> bool: | ||
| return self.method == "dflash" |
There was a problem hiding this comment.
nit: what is the point of use_eagle being used for several methods? maybe we should rename this to a more general pattern like uses_hidden_state_proposer
| normed_context_states = rmsnorm( | ||
| input=context_states, | ||
| weight=self._hidden_norm_weight, | ||
| eps=self._rms_norm_eps, | ||
| ) |
There was a problem hiding this comment.
What is the purpose of using the flashinfer rmsnorm? I'd prefer to have this use the general rmsnorm op in vLLM, and the flashinfer version should be a backend within it.
There was a problem hiding this comment.
Agree, most of this is due to my own ignorance about vLLM's kernel internals. It should be feasible to dispatch to vLLM's fused RMSNorm here. For the RoPE though, I'm not sure if we have a native fused kernel. I can look into it.
There was a problem hiding this comment.
Definitely don't need the dependency on FlashInfer long-term. But was very easy in the prototype. I'll clean this up
There was a problem hiding this comment.
I think it's certainly okay to land as-is as long as we remove the global flashinfer imports
| # In-place RoPE cannot be called here, since we use K for both query and key. | ||
| # Instead we just call the fused kernel and ignore the query output. | ||
| _, all_k_flat = apply_rope_with_cos_sin_cache( | ||
| positions=positions_repeated, | ||
| query=all_k_flat, | ||
| key=all_k_flat, | ||
| head_size=self._rope_head_size, | ||
| cos_sin_cache=self._rope_cos_sin_cache, | ||
| is_neox=self._rope_is_neox, | ||
| ) |
There was a problem hiding this comment.
ditto here, but I understand based on the comment if this is more required. I think other than these two flashinfer kernels, everything else should not necessarily have a dependency on NVIDA GPUs. If this must be kept, we should gate this path on a CUDA platform check
mgoin
left a comment
There was a problem hiding this comment.
Really really nice work! I think these are all the things I found for now, but I should take another look through soon
| from flashinfer import rmsnorm | ||
| from flashinfer.rope import apply_rope_with_cos_sin_cache |
There was a problem hiding this comment.
Ditto on CUDA platform check/lazy import when needed instead of unconditionally putting flashinfer import at the top of a model file
| num_layers = len(parent_ref.model.layers) | ||
| return (2, num_layers // 2, num_layers - 3) | ||
|
|
||
| def _get_default_eagle3_aux_hidden_state_layers(self) -> tuple[int]: |
There was a problem hiding this comment.
nit: should be tuple[int, ...] since i think it can be variable in length
| # Save context slot_mapping for pre-insertion (clone because buffer | ||
| # will be reused by _get_slot_mapping) | ||
| self._dflash_context_slot_mapping = self._slot_mapping_buffer[ | ||
| :num_context | ||
| ].clone() | ||
|
|
||
| # Only query slot_mapping for the model forward pass — context KVs | ||
| # are pre-inserted into cache before the forward. Clone to avoid | ||
| # aliasing with the buffer that _get_slot_mapping writes into. | ||
| query_slot_mapping = self._slot_mapping_buffer[ | ||
| num_context:num_all_positions | ||
| ].clone() |
There was a problem hiding this comment.
If this clones are a perf issue, I think it would be straightforward to double buffer
| tl.store(out_positions_ptr + pos_out_idx, positions, mask=in_bounds) | ||
|
|
||
| # --- Slot mapping (block_table lookup for all positions) --- | ||
| block_num = positions // block_size |
There was a problem hiding this comment.
I believe this is needs the same clamping used in the eagle kernel above
# Block table lookup: block_number = position // block_size
# Clamp block_number to avoid OOB when position is at max
block_number = clamped_position // block_size
block_number = tl.minimum(block_number, n_blocks_per_req - 1)| from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM | ||
| from vllm.model_executor.models.interfaces import SupportsMultiModal | ||
| from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM | ||
| from vllm.model_executor.models.qwen3_dflash import DFlashQwen3ForCausalLM |
There was a problem hiding this comment.
Yeah we should remove the top-level flashinfer import if we need to import the model class like this
|
This pull request has merge conflicts that must be resolved before it can be |
|
hi @benchislett do you have an image of this branch for testing? would like to test it with gpt oss 20b on rtx 6000 pro. |
Purpose
Overview
DFlash works much like P-EAGLE (see #32887), but with a major architectural change: it uses bidirectional attention between the query tokens (the last sampled token from the base model plus a bunch of placeholder mask tokens) and the context states, which are the target model's hidden states from the prefill or accepted tokens.
To implement this, I introduce an extra operation that lives outside of the main model execution, which populates the KV cache with the context states directly. Though not exposed to the standard set of torch.compile and CUDA graph optimizations, handling the context in this way allows us to use async scheduling (by writing the full set of context states to the cache and then using seq_lens_gpu to ignore the rejected ones), as well as enabling (piecewise) CUDA graphs for the main forward pass over the query tokens.
(Solved)
Because the attention is bidirectional and structured in this way, we cannot allow rejected tokens to stay in the batch as their states would corrupt the rest of the calculation, unlike in standard causal attention where we can simply omit them and sample from an earlier position. Therefore, "disable_padded_drafter_batch" is required and will be enabled when using DFlash, disabling async scheduling as a consequence. At this time I cannot think of a good way around this problem.Additionally, a small selection of kernel backends actually support non-causal attention, and a smaller set additionally include support for gpt-oss style "sinks". Flash Attention with Qwen3-8B is used as a test for this implementation, as Triton Attention and FlashInfer (TRTLLM) attention both do not support non-causality. A follow-up work here would be to allow different attention backends for the drafter and the target model. This is out-of-scope for this PR, but would enable a broader set of compatibility for DFlash models.
(Solved)
Finally, this new architecture requires extra logic to handle the new input shapes and attention metadata. It is not as simple as P-EAGLE in which we can share code with other modes of speculation, since the sizes and contents of the input tensors are all different. One compatibility component in which I am not absolutely confident is the CUDA graph support: even in piecewise mode, I am not sure how the "padded" query tokens interact with the unpadded context: is it safe to have operations on the context slice of Q inside the attention op, or must it be moved into a custom op to avoid issues? It has not yet been functional to enable torch.compile for the DFlash drafter, likely for this reason.Implementation Details
The model is implemented in
qwen3_dflash.pyand the core speculation logic is indflash.py. Similarly to draft model speculative decoding, DFlash has a unique speculation paradigm that requires some refactoring ofeagle.pyin order to support cleanly.Specifically,
build_model_inputs_first_passandbuild_per_layer_attn_metadatahave been introduced ineagle.pyso thatdflash.pycan override them.Initial implementations inlined the logic for DFlash into
eagle.py, but similarly to #24322 the added branching and complexity would (in my opinion) lead to a fairly cluttered EAGLE implementation. In this PR I have tried to separate the concerns to a reasonable degree, so that maintenance of the existing EAGLE pathway is not burdened.Testing
I add unit tests for both correctness (GSM8k) and acceptance rate checks for Qwen3-8B with DFlash. I am able to reproduce almost exactly the acceptance rate values reported in the DFlash paper, and have included this in the test suite with an exact reproduction of their evaluation setup. This should be a valuable resource and an easily extensible tool to new DFlash models as they continue to evolve and expand the DFlash model family.
In local tests on 1xB200, both Qwen3-8B and Qwen3.5-9B pass the test suite using both FA2 and FA4 on 1xB200.
Usage
Benchmarking
The latest DFlash implementation on this branch is optimized for low-latency performance. I measured the following speedups using the DFlash methodology (https://github.com/z-lab/dflash)
Datasets covered are Alpaca, GSM8k, HumanEval, Math500, MBPP, MT-Bench. Number of spec tokens for DFlash is 15.
Alpaca
GSM8k
HumanEval
Math500
MBPP
MT-Bench